%--------------------------------------------------------------------------
%                  APYN MAIN IMPUTATION PROCESS 
%-------------------------------------------------------------------------- 
%
% Title: Semi-parametric Selection Models for Potentially Non-ignorable Attrition...
% in Panel Studies with Refreshment Samples
% Manuscript ID PA-2013-094
% Y.SI, J.REITER AND S.HILLYGUS
%-------------------------------------------------------------------------- 
%               Bayesian semi-parametric selection models
%                   joint for (X Y) & W depends on (X Y)
%                        APYN data analysis
%-------------------------------------------------------------------------- 
%%%% Author: Yajuan Si
%%%% Latest edit date: 03/30/2014
%-------------------------------------------------------------------------- 
clear;clc;close all;
tic  % start the wall-clock timer for the run
%%%%--------------------read data--------------------------------------%%%%
data_original = csvread('data/apyn_data.csv');
% Column layout of data_original as used below:
%   column 1        sample indicator (6 selects panel rows, 2 refreshment rows)
%   column 2        attrition indicator
%   columns 3-11    X, columns 12-16 Y1, columns 17-21 Y2
xindex=3:11; %X PARTYID1/ID12/REL33/age4/edu5/race6/sex7/inc8/marriage9
y1index=12:16; %Y1 FAV1-4 LV1 LV3 CND1 LV31
y2index=17:21; %Y2 FAV1-4 LV1 LV3 CND1 LV31
    
%%%%----------------data organizing------------------------------------%%%%
%missing mechanism: -3 system -2 don't have much to say -1 refused
 
%X
    baselinex = data_original(data_original(:,1)== 6, xindex); %panel X
    refx = data_original(data_original(:,1)== 2, xindex); %refreshment X
%Y1 Y2
    baseliney1 = data_original(data_original(:,1)==6, y1index);
    baseliney2 = data_original(data_original(:,1)==6, y2index);
    
    % Panel rows are also identified through column 2 (values above -2).
    basesize = sum(data_original(:,2)>-2);
    % W is the attrition indicator, i.e. column 2 of the file. The previous
    % code read column 460, which is outside the documented layout (the
    % last column used anywhere else is 21).
    attrition = data_original(data_original(:,2)>-2,2); %W
    refsize = sum(data_original(:,1)==2);

    % The X/Y panel blocks are selected via column 1 and W via column 2;
    % both selections must pick the same rows, otherwise the later
    % concatenation [attrition, baselinex, ...] is invalid.
    if size(baselinex,1) ~= basesize
        error('apyn:panelRowMismatch', ...
            'Panel rows selected by column 1 (%d) and by column 2 (%d) disagree.', ...
            size(baselinex,1), basesize);
    end
    
    refy2 = data_original(data_original(:,1)==2, y2index);
    refy1 = data_original(data_original(:,1)==2, y1index); %all are -3

     
%%%%-----------data size parameters------------------------------------%%%% 
     % Np panel size, Ncp panel rows with attrition==0, Nip the remaining
     % panel rows, Nr refreshment size.
     Np = basesize; Ncp=sum(attrition==0); Nip=Np-Ncp; Nr=refsize;
     
     % Row index helpers: whole stacked data, panel, refreshment (relative),
     % first Ncp panel rows, remaining panel rows.
     N = Np+Nr; wholeindex=1:N; panelindex=1:Np; refindex=1:Nr;
     panelindex_1=1:Ncp; panelindex_2=Ncp+1:Np;
     
    % q, p1, p2: number of X, Y1, Y2 variables; p: total number of variables.
    q=length(xindex); p1=length(y1index); p2=length(y2index); p=q+p1+p2;
    
%%%%-----------transfer data form 1:d---------------------------------%%%%
% A column is treated as zero-based when its panel data contain a 0.
% In such columns every non-negative entry is moved up by one so that the
% categories start at 1; negative (missing-data) codes are left unchanged.
% Adding the logical mask (x >= 0) performs exactly that shift.

   %X: panel and refreshment covariates are shifted in the same columns
    samplexindex = 1:q;
    zeroxindex = samplexindex(sum(baselinex==0) > 0);
    baselinex(:,zeroxindex) = baselinex(:,zeroxindex) + (baselinex(:,zeroxindex) >= 0);
    refx(:,zeroxindex) = refx(:,zeroxindex) + (refx(:,zeroxindex) >= 0);

   %Y1: only the panel block is shifted
    sampley1index = 1:p1;
    zeroy1index = sampley1index(sum(baseliney1==0) > 0);
    baseliney1(:,zeroy1index) = baseliney1(:,zeroy1index) + (baseliney1(:,zeroy1index) >= 0);

   %Y2: panel and refreshment outcomes are shifted in the same columns
    sampley2index = 1:p2;
    zeroy2index = sampley2index(sum(baseliney2==0) > 0);
    baseliney2(:,zeroy2index) = baseliney2(:,zeroy2index) + (baseliney2(:,zeroy2index) >= 0);
    refy2(:,zeroy2index) = refy2(:,zeroy2index) + (refy2(:,zeroy2index) >= 0);
  
%%%%-------------------------initial data----------------------------%%%% 
   

% d(j): largest code observed in panel column j, used as the number of
% categories of variable j (refreshment rows are not consulted here).
d = max([baselinex, baseliney1, baseliney2]); %number of levels

% NOTE: the variable name "all" shadows MATLAB's built-in all() for the
% rest of the script.
all=[attrition, baselinex, baseliney1, baseliney2];
% sortrows(all,1) sorts ascending on the attrition column, so the
% attrition==0 rows (there are Ncp of them) come first.
all2=sortrows(all,1);  % re-order panel rows by W (1st column), ascending

%X-use observed marginal probabilities
% Stack sorted panel X on top of refreshment X; negative codes mark
% missing entries. Each missing entry gets a starting value drawn from the
% observed category frequencies of its column (tabulate column 3 holds
% percentages).
X_0=[all2(:,2:q+1); refx];
mar_x_index = samplexindex(sum(X_0<0)>0);      % X columns with any missing value
mar_x= (X_0(:,mar_x_index) <0); X = X_0;       % missingness mask, one column per entry of mar_x_index
for j=1:length(mar_x_index)
    prob_j=tabulate(X(mar_x(:,j)==0,mar_x_index(j)));
    xjmarsize = sum(mar_x(:,j));
    % NOTE(review): randsample needs as many weights as population
    % elements, i.e. tabulate must return d(.) rows — confirm for the data.
    X(mar_x(:,j)==1,mar_x_index(j)) =randsample(1:d(mar_x_index(j)),xjmarsize,true,prob_j(:,3)/100);
end

%Y1-use observed marginal probabilities
% Same scheme for Y1; refy1 contributes the refreshment rows (flagged
% "all are -3" where it is read, i.e. entirely missing).
Y1_0=[all2(:,q+2:q+p1+1); refy1]; 
mis_y1_index = sampley1index(sum(Y1_0<0)>0);   % Y1 columns with any missing value
mis_y1= (Y1_0(:,mis_y1_index) <0); Y1 = Y1_0; 
for j=1:length(mis_y1_index)
    prob_j=tabulate(Y1(mis_y1(:,j)==0,mis_y1_index(j)));
    y1jmissize = sum(mis_y1(:,j));
    Y1(mis_y1(:,j)==1,mis_y1_index(j)) =randsample(1:d(q+mis_y1_index(j)),y1jmissize,true,prob_j(:,3)/100); 
end

%Y2-treat the missing values separately
% Stack sorted panel Y2 on top of refreshment Y2. Y2_0 keeps the raw
% codes; Y2 is the working copy that receives starting values.
Y2_0=[all2(:,q+p1+2:p+1); refy2]; 
Y2 = Y2_0; 
 mis_y2_index = sampley2index(sum(Y2_0<0)>0);  % Y2 columns with any negative code
 mis_y2= (Y2_0(:,mis_y2_index) <0);            % mask over all N rows, columns follow mis_y2_index

%MAR item nonresponse in the panel with W=1 
% Columns are selected when any panel row holds a code strictly between
% -3 and 0; the mask itself covers only the first Ncp rows and flags every
% negative code there.
panelmar_y2_index = sampley2index(sum(Y2_0(panelindex,:)<0 & Y2_0(panelindex,:) > -3)>0);
panelmar_y2 = Y2_0(1:Ncp,panelmar_y2_index) <0; 

for j=1:length(panelmar_y2_index)
    y2jmarsize=sum(panelmar_y2(:,j));
    % Observed frequencies among the first Ncp rows (the Ncp-long logical
    % mask indexes only those rows of Y2).
    panel_prob_j_w1=tabulate(Y2(panelmar_y2(:,j)==0,panelmar_y2_index(j)));
    Y2(panelmar_y2(:,j)==1,panelmar_y2_index(j))=...
        randsample(1:d(q+p1+panelmar_y2_index(j)), y2jmarsize, true, panel_prob_j_w1(:,3)/100);
end   
%NMAR attrition in the panel with W=0
% Wave-2 outcomes coded -3 in the panel rows are given starting values
% from an implied distribution: the refreshment-sample frequencies are
% treated as a mixture of the first-Ncp-rows frequencies (weight Ncp/Np)
% and the unknown distribution of the remaining rows (weight Nip/Np),
% and the latter is solved for.
panelnmar_y2_index = sampley2index(sum(Y2_0(panelindex,:)==-3)>0);
panelnmar_y2 = Y2_0(panelindex,panelnmar_y2_index)==-3;

for j=1:length(panelnmar_y2_index)
    
   y2jnmarsize =sum(panelnmar_y2(:,j)); 
   % Observed-value masks are taken from the raw codes of the variable
   % itself (panelnmar_y2_index(j)). Indexing mis_y2 by loop position j is
   % only right when mis_y2_index and panelnmar_y2_index coincide.
   ref_prob_j=tabulate(refy2(~(Y2_0(Np+1:N,panelnmar_y2_index(j))<0),panelnmar_y2_index(j))); 
   stay_prob_j=tabulate(Y2(~(Y2_0(1:Ncp,panelnmar_y2_index(j))<0),panelnmar_y2_index(j))); 
   prob_j=(ref_prob_j(:,3)/100-stay_prob_j(:,3)/100*Ncp/Np)/(Nip/Np);
   % The difference of two empirical distributions can dip below zero;
   % randsample rejects negative weights, so truncate at zero (randsample
   % renormalises the weights itself). If nothing positive is left, fall
   % back to the refreshment frequencies.
   prob_j=max(prob_j,0);
   if sum(prob_j)==0, prob_j=ref_prob_j(:,3)/100; end
   Y2(panelnmar_y2(:,j)==1,panelnmar_y2_index(j)) =randsample(1:d(q+p1+panelnmar_y2_index(j)),y2jnmarsize,true,prob_j); 
end

%MAR item nonresponse in the refreshment sample
% Negative codes in refy2 are replaced by draws from the observed
% refreshment frequencies of the same column; the filled copy then
% overwrites the refreshment rows (Np+1:N) of Y2.
refmis_y2_index = sampley2index(sum(refy2<0)>0);  % refreshment Y2 columns with any missing value
refmis_y2= refy2(:,refmis_y2_index) <0;           % mask, columns follow refmis_y2_index

refy2_ini=refy2;
for j=1:length(refmis_y2_index)
    y2jrefsize=sum(refmis_y2(:,j));
    ref_prob_j=tabulate(refy2(refmis_y2(:,j)==0,refmis_y2_index(j)));
    refy2_ini(refmis_y2(:,j)==1,refmis_y2_index(j))=...
        randsample(1:d(q+p1+refmis_y2_index(j)), y2jrefsize, true, ref_prob_j(:,3)/100);
end  
Y2(Np+1:N,:)=refy2_ini;

%initial data
% Completed starting data set: N rows (sorted panel, then refreshment),
% p columns ordered X, Y1, Y2.
data=[X Y1 Y2];

%not use ID/REL/PARTYID AS X
% Only columns 4..p of data enter the attrition model.
use_var=4:p; wx_data=data(:,use_var); x_d=d(use_var);
xp=sum(x_d)-(p-3); %dimension after dummy coding

% index_dc lists the dummy columns that are kept: every position except
% the last level of each variable (the positions in cumsum(x_d)).
% Position 1 is always kept because index_dc starts as all ones and the
% loop begins at j=2.
del_index = cumsum(x_d);
index_dc=ones(xp,1);
temp=1;
for j=2:sum(x_d)
    if sum(j==del_index)==0
        temp=temp+1;
        index_dc(temp)=j;
    end    
end 
%X after dummy coding
% NOTE(review): dummyvar sizes each variable's block from the largest
% value present in wx_data, so index_dc lines up only if every variable
% attains its level count d(.) in the current data — confirm.
x_dc = dummyvar(wx_data); w_x = x_dc(:,index_dc);

%initial values for beta
% Probit regression on the panel rows: response 1 for the first Ncp rows,
% 0 for the remaining Nip rows.
beta_w=glmfit(w_x(1:Np,:), [ones(Ncp,1); zeros(Nip,1)],'binomial','link','probit');
%initial W
% Panel W is fixed (1 for the first Ncp rows, 0 for the next Nip); the
% refreshment rows start from Bernoulli(Ncp/Np) draws.
W=[ones(Ncp,1)', zeros(Nip,1)', binornd(1,Ncp/Np, [Nr, 1])']';

%%
%%%%------------MCMC global parameters & initial setting---------------%%%%
Trun=100000; % size of MCMC sample 
% With burn=0 and thin=1 every iteration is stored (effsize == Trun).
burn =0; thin = 1; effsize = (Trun-burn)/thin ;
H = 30; %truncation #classes
zupdateprob = zeros(N,H);   % unnormalised class-membership weights, unit x class
cc_f = zeros(H,1);          % scratch vector used in the stick-breaking update


%%%%-----------------MCMC output files---------------------------------%%%%
alpha_out = zeros(effsize,1);
beta_w_out = zeros(effsize,xp+1); cell_out=zeros(effsize,p);
pi_out = zeros(effsize,H);
kout = zeros(effsize,1);  %number of unique clusters
y1_out = zeros(effsize,p1);
y2_out = zeros(effsize,p2); w_out = zeros(effsize,1);

%%%%--------------output for multiple imputation-----------------------%%%%
m = 500; %no of completed datasets; large value to decrease effects of CR
% cut: the iterations at which a completed data set is stored, spaced 100
% apart and counted back from the last iteration. imp counts stored sets.
cut = Trun - 100 * (0:1:m-1);  imp = 0;
% Stacked storage, N rows per stored set: column 1 set number, columns
% 2..p+1 the data, column p+2 the W vector.
impdata_out = zeros(N*m,p+2);repdata_out_1= zeros(N*m,p+2);

%%%%--------------------prior specification----------------------------%%%%
a0=1/4; b0=1/4; %for alpha
% Sigma0 enters the beta_w update as eye(xp+1)/Sigma0 added to W_x'*W_x.
Sigma0 =1000; %prior for beta: mean 0, covariance Sigma0 * I(xp+1)
a_j = ones(max(d),1);  %for phi in dirichlet/beta distribution

%%%%------------Initial values for mixture components------------------%%%%

% rows/cols pair each cell of data (read column-wise) with its variable
% number and its current category; lin_idx is the matching linear index
% into a p-by-max(d) table.
rows = repmat(1:p,N,1); rows = rows(:); cols = data(:);

lin_idx = sub2ind([p,max(d)],rows,cols); %category index

phi=zeros(H,p,max(d));      % cell probabilities

% Every class starts from the empirical category proportions of the
% initial completed data.
for h=1:H
    for j=1:p
        for l=1:d(j)
        phi(h,j,l)=sum(data(:,j)==l)/N;
        end
    end
end

mat=zeros(p,max(d));    % for updating phi
alpha =1;
nus = betarnd(1,alpha,[1 H]);  % stick-breaking random variables
nus(H)=1;                      % last stick takes all remaining mass
nu = nus.*cumprod([1 1-nus(:,1:H-1)]);     % category probabilities
nu=nu/sum(nu);
ph = randsample(1:H,N,true,nu); %latent class indicator
%%
%%%%------------------MCMC main process--------------------------------%%%%
for t=1:Trun
    
  % 1) -- update stick-breaking random variables -- %
  % For each h the current weights are divided by the old stick fraction,
  % a new fraction is drawn from Beta(1 + #{ph==h}, alpha + #{ph>h}), and
  % nu(h) and the later weights nu(h+1:H) are rescaled with the new value.

  for h = 1:H-1
      cc_f(h) = nu(h)/nus(h); 
      if nus(h)~=1, cc_f(h+1:H) = nu(h+1:H)/(1-nus(h)); end
      nus(h) = betarnd(1 + sum(ph==h), alpha + sum(ph>h));
      nu(h) = cc_f(h)*nus(h); 
      if nus(h)==1, nu(h+1:H) = 0; else nu(h+1:H) = cc_f(h+1:H)*(1-nus(h)); end
      % keep the stored fraction strictly below 1 (applied after nu has
      % been updated with the unclamped draw)
      if nus(h)>1-1e-5, nus(h)=1-1e-5; end 
   end
   
   % 2)-- update alpha -- %
   % Gamma draw with shape a0+H-1 and scale 1/(b0 - sum(log(1-nus)));
   % 1-nus is floored at 1e-6 so the logarithm stays finite.
    nuss = 1 - nus(1:H-1); nuss(nuss < 1e-6) = 1e-6; 
    alpha = gamrnd(a0 + H-1, 1/(b0 - sum(log(nuss))));

   % 3)-- update ph: allocation to atoms -- %
   % For each class h the weight of unit i is nu(h) times the product over
   % all p variables of phi(h, j, data(i,j)). Weights are normalised per
   % unit and a class is drawn by comparing one uniform draw with the
   % cumulative probabilities.
    
      cols = data(:); 
      lin_idx = sub2ind([p,max(d)],rows,cols); 
       
        for h = 1:H
          
            phih = reshape(phi(h,:,:),p,max(d)); 
            tmpmatL = reshape(phih(lin_idx),[N,p]); 
            
            zupdateprob(:,h) = nu(h) * prod(tmpmatL,2);   
        
        end
        zupdateprob1 = bsxfun(@times,zupdateprob,1./(sum(zupdateprob,2)));
        mat1 = [zeros(N,1) cumsum(zupdateprob1,2)];
        rr = unifrnd(0,1,[N,1]);
        for l = 1:H
            ind = rr > mat1(:,l) & rr <= mat1(:,l+1); ph(ind) = l; 
        end
        
   
   % 4)-- update Phi -- %
   % For class h, mat(j,c) = a_j(c) + number of class members with
   % data(:,j)==c. Independent Gamma(mat,1) draws, normalised over the
   % first d(j) categories of each variable j, give the new phi(h,j,:).
    for h = 1:H
           
        nh = reshape(ph==h,[N 1]);
        
        for c = 1:max(d)
            mat(:,c) = (a_j(c) + sum(bsxfun(@times,(data==c),nh)))';
        end
            
        Lamh1 = gamrnd(mat,1); 
        for j=1:p
            Lamh = bsxfun(@times,Lamh1(j,1:d(j)), 1./sum(Lamh1(j,1:d(j))));
            phi(h,j,1:d(j)) = Lamh;
        end
    end
    

  
    % 5) -- update beta_w -- %
    % Probit model for W on the dummy-coded columns use_var of the current
    % completed data, all N rows (the refreshment rows use the currently
    % imputed W).
    % NOTE(review): the structure (hat-matrix weights, truncated-normal
    % latent draws, incremental update of the mean BB) looks like a
    % Holmes & Held style auxiliary-variable sampler applied to all units
    % at once — confirm against the paper.

   wx_data=data(:,use_var);
   xx_dc = dummyvar(wx_data);    
   W_x = [ones(N,1) xx_dc(:,index_dc)];   % design matrix with intercept

    VV=inv(eye(xp+1)/Sigma0+W_x'* W_x);    % covariance used for the beta_w draw
    SS = (eye(xp+1)/Sigma0+W_x'* W_x)\W_x';  %(p+1)*N matrix 
    % hdiag: diagonal of W_x*SS; wgh and zt_v are the per-unit weight and
    % variance used for the latent draws below.
    HM = W_x * SS; hdiag = diag(HM); wgh = hdiag./(1-hdiag);zt_v=1+wgh;

    % The latent vector restarts from zero in every outer iteration, so BB
    % starts as a zero vector.
    zt=zeros(N,1);
    BB = (eye(xp+1)/Sigma0+W_x'* W_x)\( W_x'* zt);

    % 100 inner sweeps; only the beta_w from the last sweep is used after
    % the loop.
    for t_inc=1:100

    % a) -- update zt -- %

        zt_old=zt;
        mi=W_x * BB;
        mi=mi-wgh .* (zt_old-mi);
        % af = P(latent <= 0); the two clamps keep the sampling interval
        % for the observed W non-degenerate.
        af=normcdf(zeros(N,1), mi, sqrt(zt_v));
        af(af==1 & W==1) = 0.99999;
        af(af==0 & W==0) = 0.00001;
        % Inverse-CDF draw from the normal truncated to (-inf,0] when W==0
        % and to [0,inf) when W==1. Both unifrnd calls are evaluated for
        % every unit; the indicator picks one.
        ui=unifrnd(0,af).*(W==0)+unifrnd(af,1).*(W==1);
        zt=norminv(ui,mi,sqrt(zt_v));
        BB = BB + SS * (zt-zt_old);
 
    % b) -- update beta_w -- %
        beta_w =reshape(mvnrnd(BB,VV),[xp+1,1]);

    end
         
   
    % 6)-- impute missing values of X -- %
    % For every X column with missing entries a candidate value is drawn
    % for each missing cell from phi of the unit's current class.
    % Refreshment rows take the candidate directly. Panel rows take it
    % with probability Phi(x*beta_w)^W * (1-Phi(x*beta_w))^(1-W), where x
    % is the dummy-coded row of the candidate data set; otherwise they
    % keep their current value.

    for j=1:length(mar_x_index)
        xjmarsize = sum(mar_x(:,j));
        xjindex= wholeindex(mar_x(:,j)==1);
        % rows: missing units; columns: category probabilities of their class
        phi_x = reshape(phi(ph(mar_x(:,j)==1),mar_x_index(j),1:d(mar_x_index(j))),...
        [xjmarsize d(mar_x_index(j))]);
        mat_x = [zeros(xjmarsize,1) cumsum(phi_x,2)];
        r_x = unifrnd(0,1,[xjmarsize,1]);

        % data_new holds the candidates; data is only changed on acceptance
        data_new=data;
         for l=1:d(mar_x_index(j))
            misind1 =r_x > mat_x(:,l) & r_x <= mat_x(:,l+1);
            data_new(xjindex(misind1),mar_x_index(j))=l; 
         end    
    
        % refreshment rows: copy the whole column segment from data_new
        data(Np+1:N,mar_x_index(j)) = data_new(Np+1:N,mar_x_index(j)); 
 
        panelxjmarsize=sum(mar_x(1:Np,j));
 
        wx_data_new=data_new(:,use_var);
        xxnew_dc = dummyvar(wx_data_new);
        xxjpanelmisindex = panelindex(mar_x(1:Np,j)==1);
        X_new_data=[ones(panelxjmarsize,1) xxnew_dc(xxjpanelmisindex,index_dc)];
        % acceptance probability: probability of the observed W under the candidate
        x_acp_w = (normcdf(X_new_data * beta_w)).^W(xxjpanelmisindex).*...
        ((1-normcdf(X_new_data * beta_w)).^(1-W(xxjpanelmisindex)));
        jind=(rand(panelxjmarsize,1) <= x_acp_w);
        data(xxjpanelmisindex(jind),mar_x_index(j))=data_new(xxjpanelmisindex(jind),mar_x_index(j)); 
    end
    
    % 7)-- impute missing values of Y1 -- %
    % Same scheme as for X: candidates come from phi of each unit's class;
    % refreshment rows take them directly, panel rows take them with the
    % probability of their observed W under the candidate data.

     
    for j=1:length(mis_y1_index)
       
    y1jmissize = sum(mis_y1(:,j));
    y1jindex= wholeindex(mis_y1(:,j)==1);
    jdata = q+mis_y1_index(j);   % column of this Y1 variable in data
    
    phi_y1 = reshape(phi(ph(y1jindex),jdata,1:d(jdata)), [y1jmissize d(jdata)]);
         
    mat_y1 = [zeros(y1jmissize,1) cumsum(phi_y1,2)];
    r_y1 = unifrnd(0,1,[y1jmissize,1]);
    
         
     % data_new holds the candidates; data is only changed on acceptance
     data_new=data;
    
         for l=1:d(jdata)
             misind1 =r_y1 > mat_y1(:,l) & r_y1 <= mat_y1(:,l+1);
             data_new(y1jindex(misind1),jdata)=l; 
         end   
    

    % refreshment rows: copy the whole column segment from data_new
    data(Np+1:N,jdata) = data_new(Np+1:N,jdata); 
    panely1jmarsize=sum(mis_y1(1:Np,j));
  
    wx_data_new=data_new(:,use_var);
    xxnew_dc = dummyvar(wx_data_new);
    y1jpanelmisindex = panelindex(mis_y1(1:Np,j)==1);
    X_new_data=[ones(panely1jmarsize,1) xxnew_dc(y1jpanelmisindex,index_dc)];
    % acceptance probability: Phi^W * (1-Phi)^(1-W) at the candidate row
    y1_acp_w = (normcdf(X_new_data * beta_w)).^W(y1jpanelmisindex).*...
    ((1-normcdf(X_new_data * beta_w)).^(1-W(y1jpanelmisindex)));
    jind=(rand(panely1jmarsize,1) <= y1_acp_w);

        
    data(y1jpanelmisindex(jind),jdata)=data_new(y1jpanelmisindex(jind),jdata); 
    end   
     
    % 8)-- impute missing values of Y2 due to MAR in the refreshment -- %
    % Missing refreshment Y2 cells are drawn from phi of the unit's class
    % and written straight into data (no acceptance step).
   
     
    for j=1:length(refmis_y2_index)
         
    refy2jmissize = sum(refmis_y2(:,j));
    % refreshment-relative row numbers shifted by Np to address rows of data
    refy2jindex= refindex(refmis_y2(:,j)==1)+Np;
    jdata = q+p1+refmis_y2_index(j);   % column of this Y2 variable in data
         
    phi_y2 = reshape(phi(ph(refy2jindex),jdata,1:d(jdata)), [refy2jmissize d(jdata)]);
        

    mat_y2 = [zeros(refy2jmissize,1) cumsum(phi_y2,2)];
        
    r_y2 = unifrnd(0,1,[refy2jmissize,1]);
        for l=1:d(jdata)
         misind2 =r_y2 > mat_y2(:,l) & r_y2 <= mat_y2(:,l+1);
         data(refy2jindex(misind2),jdata)=l; 
        end
            
    end 

     
    % 9)-- impute missing values of Y2 due to MAR IN THE PANEL /W=1 -- %
    % Applies to the first Ncp panel rows. Candidates are drawn from phi
    % of the unit's class and accepted with probability Phi(x*beta_w)
    % evaluated at the candidate row; rejected units keep their value.

     
    for j=1:length(panelmar_y2_index)
         
    panely2jmarsize = sum(panelmar_y2(:,j));
    panely2jmarindex= panelindex_1(panelmar_y2(:,j)==1);
    jdata = q+p1+panelmar_y2_index(j);   % column of this Y2 variable in data
         
    phi_y2 = reshape(phi(ph(panely2jmarindex),jdata,1:d(jdata)), [panely2jmarsize d(jdata)]);
    % data_new holds the candidates; data is only changed on acceptance
    data_new=data;

    mat_y2 = [zeros(panely2jmarsize,1) cumsum(phi_y2,2)];
        
    r_y2 = unifrnd(0,1,[panely2jmarsize,1]);
        for l=1:d(jdata)
            misind2 =r_y2 > mat_y2(:,l) & r_y2 <= mat_y2(:,l+1);
            data_new(panely2jmarindex(misind2),jdata)=l; 
        end
    
  
    wx_data_new=data_new(:,use_var);
    xxnew_dc = dummyvar(wx_data_new);
    X_new_data=[ones(panely2jmarsize,1) xxnew_dc(panely2jmarindex,index_dc)];
    jind=(rand(panely2jmarsize,1) <= normcdf(X_new_data * beta_w));
        
    data(panely2jmarindex(jind),jdata)=data_new(panely2jmarindex(jind),jdata);
             
    end 
     
    % 10)-- impute missing values of Y2 due to NMAR IN THE PANEL /W=0--%
    % Applies to panel rows Ncp+1..Np whose raw Y2 code is -3. Candidates
    % are drawn from phi of the unit's class and accepted with probability
    % 1 - Phi(x*beta_w) evaluated at the candidate row.
    for j=1:length(panelnmar_y2_index)
         
    % NOTE(review): the size is counted over all Np panel rows while the
    % index below covers only rows Ncp+1..Np; the two agree only if no -3
    % occurs among the first Ncp rows — confirm for the data.
    panely2jnmarsize = sum(panelnmar_y2(:,j));
    panely2jnmarindex= panelindex_2(panelnmar_y2(Ncp+1:Np,j)==1);
    jdata = q+p1+panelnmar_y2_index(j);   % column of this Y2 variable in data
         
    phi_y2 = reshape(phi(ph(panely2jnmarindex),jdata,1:d(jdata)), [panely2jnmarsize d(jdata)]);
    % data_new holds the candidates; data is only changed on acceptance
    data_new=data;

    mat_y2 = [zeros(panely2jnmarsize,1) cumsum(phi_y2,2)];   
    r_y2 = unifrnd(0,1,[panely2jnmarsize,1]);
        for l=1:d(jdata)
         misind2 =r_y2 > mat_y2(:,l) & r_y2 <= mat_y2(:,l+1);
         data_new(panely2jnmarindex(misind2),jdata)=l; 
        end
    

    wx_data_new=data_new(:,use_var);
    xxnew_dc = dummyvar(wx_data_new);
    X_new_data=[ones(panely2jnmarsize,1) xxnew_dc(panely2jnmarindex,index_dc)];
    jind=(rand(panely2jnmarsize,1) <= 1-normcdf(X_new_data * beta_w));
       
    data(panely2jnmarindex(jind),jdata)=data_new(panely2jnmarindex(jind),jdata);
             
    end 
 
    
    % 11)-- impute missing values of W -- %
    % W is unobserved for the refreshment rows (Np+1..N); each is redrawn
    % as Bernoulli with success probability Phi(x*beta_w) computed from
    % the current completed data.
    wx_data=data(:,use_var);
    
    xx_dc = dummyvar(wx_data);
    X_beta=[ones(Nr,1) xx_dc(Np+1:N,index_dc)];

    W(Np+1:N) = (rand(Nr,1)<=normcdf(X_beta * beta_w)); 
    
%%% (0)--------------- Posterior predictive check --------------------- %%%
    % Replicated data set: a copy of the current completed data in which
    % every Y2 column is redrawn for all N units from phi of each unit's
    % current class.
    data_rep1=data;

    for j=1:p2
  
    jdata = q+p1+j;   % column of this Y2 variable in data
    
    phi_y2_all = reshape(phi(ph,jdata,1:d(jdata)), [N d(jdata)]);        
    mat_y2_all = [zeros(N,1) cumsum(phi_y2_all,2)];
    r_y2_all = unifrnd(0,1,[N,1]);

        for l=1:d(jdata)
            ind =r_y2_all > mat_y2_all(:,l) & r_y2_all <= mat_y2_all(:,l+1);
            data_rep1(ind,jdata)=l; 
        end
    
    end

    % Replicated W for all N units, drawn from the probit model evaluated
    % at the replicated data. The storage step writes W_rep1 next to
    % data_rep1, but it was never assigned, so the script stopped with an
    % undefined-variable error at the first stored iteration. It is only
    % needed at the iterations listed in cut, so it is only drawn there.
    if sum(cut==t)>0
        wx_data_rep1 = data_rep1(:,use_var);
        xx_dc_rep1 = dummyvar(wx_data_rep1);
        W_rep1 = double(rand(N,1) <= ...
            normcdf([ones(N,1) xx_dc_rep1(:,index_dc)] * beta_w));
    end

     % 12) -- save sampled values (after thinning) -- %
     % Row (t-burn)/thin of each output array receives the current draw.
    
    if mod(t,thin)==0 && t > burn
       % per variable: sum over classes of nu(h)*phi(h,j,1), i.e. the
       % mixture probability of category 1
       cell_out((t-burn)/thin, :) = sum(bsxfun(@times,phi(:,:,1),nu'));   
       alpha_out((t-burn)/thin) = alpha;
     
       beta_w_out((t-burn)/thin,:) = beta_w;
       pi_out((t-burn)/thin,:) = nu;
       
       kout((t-burn)/thin) = length(unique(ph));   % number of occupied classes

       % column means of the completed Y1 and Y2 blocks, and the mean of
       % the imputed refreshment W
       y1_out((t-burn)/thin,:) = mean(data(:,q+1:q+p1));
       y2_out((t-burn)/thin,:) = mean(data(:,q+p1+1:p));
       w_out((t-burn)/thin) = mean(W(Np+1:N));
    end
    
    % 13) --store multiple completed data sets ---%
    % At the iterations listed in cut, append the completed data and the
    % replicated data as the next block of N rows: column 1 set number,
    % columns 2..p+1 data, column p+2 the W vector.
    % NOTE(review): W_rep1 is not assigned in this step; it has to exist
    % from earlier in the iteration or the last assignment below fails.
      if sum(cut==t)>0
         imp=imp+1;
         impdata_out(((imp-1)*N+1):imp*N,1) = repmat(imp, [N 1]);
         impdata_out(((imp-1)*N+1):imp*N,2:(p+1)) = data;
         impdata_out(((imp-1)*N+1):imp*N,p+2) = W;
          repdata_out_1(((imp-1)*N+1):imp*N,1) = repmat(imp, [N 1]);
         repdata_out_1(((imp-1)*N+1):imp*N,2:(p+1)) = data_rep1;
         repdata_out_1(((imp-1)*N+1):imp*N,p+2) = W_rep1;
      end
        
    
    % progress report: the missing semicolon echoes t every 1000 iterations
    if mod(t,1000)==0
        t             
    end
    
end  %mcmc end
%%
%%%%-----------save the whole working space into one object------------%%%%
%The saved output will be used by the analysis code:
%code-for-APYN-analysis.m
toc  % report the elapsed time of the timer started at the top of the script
% save fails if the target folder is missing, which would discard the
% whole run; create it first when necessary.
if exist('output','dir') ~= 7
    mkdir('output');
end
save output/apyndata-output.mat